scikit-learn docs provide a nice text classification tutorial. Make sure to read it first. We'll be doing something similar, while taking a more detailed look at classifier weights and predictions.
First, we need some data. Let's load the 20 Newsgroups data, keeping only 4 categories:
In [1]:
from sklearn.datasets import fetch_20newsgroups
categories = ['alt.atheism', 'soc.religion.christian',
              'comp.graphics', 'sci.med']
twenty_train = fetch_20newsgroups(
    subset='train',
    categories=categories,
    shuffle=True,
    random_state=42
)
twenty_test = fetch_20newsgroups(
    subset='test',
    categories=categories,
    shuffle=True,
    random_state=42
)
A basic text processing pipeline - bag of words features and Logistic Regression as a classifier:
In [2]:
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegressionCV
from sklearn.pipeline import make_pipeline
vec = CountVectorizer()
clf = LogisticRegressionCV()
pipe = make_pipeline(vec, clf)
pipe.fit(twenty_train.data, twenty_train.target);
We're using LogisticRegressionCV here to adjust the regularization parameter C automatically. This makes it possible to compare different vectorizers fairly - the optimal C value could be different for different input features (e.g. for bigrams or for character-level input). An alternative would be to use GridSearchCV or RandomizedSearchCV.
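For example, a roughly equivalent GridSearchCV setup could look like the sketch below (the C grid and cv=3 are arbitrary illustration choices, not part of this tutorial; it is also slower, because every candidate C refits the vectorizer as well):

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

# search over a hypothetical grid of C values, cross-validating the whole pipeline
grid_pipe = GridSearchCV(
    make_pipeline(CountVectorizer(), LogisticRegression()),
    param_grid={'logisticregression__C': [0.01, 0.1, 1.0, 10.0]},
    cv=3,
)
grid_pipe.fit(twenty_train.data, twenty_train.target);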
Let's check the quality of this pipeline:
In [3]:
from sklearn import metrics
def print_report(pipe):
    y_test = twenty_test.target
    y_pred = pipe.predict(twenty_test.data)
    report = metrics.classification_report(y_test, y_pred,
        target_names=twenty_test.target_names)
    print(report)
    print("accuracy: {:0.3f}".format(metrics.accuracy_score(y_test, y_pred)))
print_report(pipe)
Not bad. We can try other classifiers and preprocessing methods, but let's first check what the model has learned, using the eli5.show_weights function:
In [4]:
import eli5
eli5.show_weights(clf, top=10)
Out[4]:
The table above doesn't make any sense; the problem is that eli5 was not able to get feature and class names from the classifier object alone. We can provide feature and target names explicitly:
In [5]:
# eli5.show_weights(clf,
#                   feature_names=vec.get_feature_names(),
#                   target_names=twenty_test.target_names)
The code above works, but a better way is to provide the vectorizer instead and let eli5 figure out the details automatically:
In [6]:
eli5.show_weights(clf, vec=vec, top=10,
                  target_names=twenty_test.target_names)
Out[6]:
This starts to make more sense. Columns are target classes; in each column there are features and their weights. The intercept (bias) feature is shown as <BIAS> in the same table. We can inspect features and weights because we're using a bag-of-words vectorizer and a linear classifier (so there is a direct mapping between individual words and classifier coefficients). For other classifiers features can be harder to inspect.
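Under the hood this mapping is just clf.coef_ paired with the vectorizer's vocabulary. A minimal sketch of pulling out a few top coefficients manually (using vec.get_feature_names(), as elsewhere in this tutorial; newer scikit-learn versions call it get_feature_names_out()):

import numpy as np

feature_names = np.asarray(vec.get_feature_names())
for class_idx, class_name in enumerate(twenty_train.target_names):
    # clf.coef_ has shape (n_classes, n_features); a larger weight means
    # stronger evidence for that class in a linear model
    top = np.argsort(clf.coef_[class_idx])[::-1][:5]
    print(class_name, list(zip(feature_names[top], clf.coef_[class_idx][top].round(2))))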
Some features look good, but some don't. It seems the model learned some names specific to the dataset (parts of email addresses, etc.) instead of learning topic-specific words. Let's check prediction results on an example:
In [7]:
eli5.show_prediction(clf, twenty_test.data[0], vec=vec,
                     target_names=twenty_test.target_names)
Out[7]:
What can be highlighted in the text is highlighted. There is also a separate table for features which can't be highlighted in the text - <BIAS> in this case. If you hover the mouse over a highlighted word, its weight is shown in a tooltip. Words are colored according to their weights.
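If you're working outside a notebook, a plain-text rendering of the same explanation can be produced with eli5's explain_prediction and format_as_text functions (a sketch):

expl = eli5.explain_prediction(clf, twenty_test.data[0], vec=vec,
                               target_names=twenty_test.target_names)
print(eli5.format_as_text(expl))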
Aha, from the highlighting above it can be seen that the classifier indeed learned some uninteresting things, e.g. it remembered parts of email addresses. We should probably clean the data first to make it more interesting; improving the model (trying different classifiers, etc.) doesn't make sense at this point - it may just learn to leverage these email addresses better.
In practice we'd have to do the cleaning ourselves; in this example the 20 Newsgroups dataset provides an option to remove footers and headers from the messages. Nice. Let's clean up the data and re-train the classifier.
In [8]:
twenty_train = fetch_20newsgroups(
    subset='train',
    categories=categories,
    shuffle=True,
    random_state=42,
    remove=['headers', 'footers'],
)
twenty_test = fetch_20newsgroups(
    subset='test',
    categories=categories,
    shuffle=True,
    random_state=42,
    remove=['headers', 'footers'],
)
vec = CountVectorizer()
clf = LogisticRegressionCV()
pipe = make_pipeline(vec, clf)
pipe.fit(twenty_train.data, twenty_train.target);
We just made the task harder and more realistic for a classifier.
In [9]:
print_report(pipe)
A great result - we just made quality worse! Does it mean the pipeline is worse now? No, it likely has better quality on unseen messages; it is the evaluation that is fairer now. Inspecting the features used by the classifier allowed us to notice a problem with the data and make a good change, despite the numbers telling us not to do it.
Instead of removing headers and footers we could have improved the evaluation setup directly, e.g. by using GroupKFold from scikit-learn. Then the quality of the old model would have dropped, we could have removed headers/footers and seen the accuracy increase, so the numbers would have told us to remove headers and footers. It is not obvious how to split the data though - which groups to use with GroupKFold.
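For illustration only, a sketch of what that could look like, under the assumption that grouping messages by sender is reasonable (get_sender below is a hypothetical helper, and the senders are taken from a re-fetched raw training set that still has its headers):

import re
from sklearn.model_selection import GroupKFold, cross_val_score

# same subset, order and categories as twenty_train, but with headers kept
raw_train = fetch_20newsgroups(subset='train', categories=categories,
                               shuffle=True, random_state=42)

def get_sender(text):
    # crude heuristic: use the "From:" header line as the group label
    match = re.search(r'^From:\s*(.+)$', text, re.MULTILINE)
    return match.group(1) if match else 'unknown'

groups = [get_sender(text) for text in raw_train.data]
scores = cross_val_score(pipe, twenty_train.data, twenty_train.target,
                         groups=groups, cv=GroupKFold(n_splits=5))
print(scores.mean())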
So, what has the updated classifier learned? (The output is less verbose because only a subset of classes is shown - see the "targets" argument):
In [10]:
eli5.show_prediction(clf, twenty_test.data[0], vec=vec,
                     target_names=twenty_test.target_names,
                     targets=['sci.med'])
Out[10]:
Hm, it no longer uses email addresses, but it still doesn't look good: the classifier assigns high weights to seemingly unrelated words like 'do' or 'my'. These words appear in many texts, so maybe the classifier uses them as a proxy for bias. Or maybe some of them are more common in some of the classes.
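A quick way to check the last guess (a sketch, not part of the original pipeline): count how often these words occur on average per message in each class of the training data:

import numpy as np

check_vec = CountVectorizer(vocabulary=['do', 'my'])
counts = check_vec.fit_transform(twenty_train.data).toarray()
for class_idx, class_name in enumerate(twenty_train.target_names):
    mask = twenty_train.target == class_idx
    # mean counts of 'do' and 'my' per message in this class
    print(class_name, counts[mask].mean(axis=0).round(2))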
To help the classifier we may filter out stop words:
In [11]:
vec = CountVectorizer(stop_words='english')
clf = LogisticRegressionCV()
pipe = make_pipeline(vec, clf)
pipe.fit(twenty_train.data, twenty_train.target)
print_report(pipe)
In [12]:
eli5.show_prediction(clf, twenty_test.data[0], vec=vec,
                     target_names=twenty_test.target_names,
                     targets=['sci.med'])
Out[12]:
Looks better, doesn't it?
Alternatively, we can use the TF*IDF scheme; it should give a somewhat similar effect.
Note that we're cross-validating the LogisticRegression regularization parameter here, like in the other examples (LogisticRegressionCV, not LogisticRegression). TF*IDF values are different from word count values, so the optimal C value can be different. We could draw a wrong conclusion if a classifier with a fixed regularization strength were used - the chosen C value could have worked better for one kind of data.
In [13]:
from sklearn.feature_extraction.text import TfidfVectorizer
vec = TfidfVectorizer()
clf = LogisticRegressionCV()
pipe = make_pipeline(vec, clf)
pipe.fit(twenty_train.data, twenty_train.target)
print_report(pipe)
In [14]:
eli5.show_prediction(clf, twenty_test.data[0], vec=vec,
                     target_names=twenty_test.target_names,
                     targets=['sci.med'])
Out[14]:
It helped, but didn't have quite the same effect. Why not do both?
In [15]:
vec = TfidfVectorizer(stop_words='english')
clf = LogisticRegressionCV()
pipe = make_pipeline(vec, clf)
pipe.fit(twenty_train.data, twenty_train.target)
print_report(pipe)
In [16]:
eli5.show_prediction(clf, twenty_test.data[0], vec=vec,
                     target_names=twenty_test.target_names,
                     targets=['sci.med'])
Out[16]:
We can also try analyzing characters instead of words, using char n-grams:
In [17]:
vec = TfidfVectorizer(stop_words='english', analyzer='char',
                      ngram_range=(3,5))
clf = LogisticRegressionCV()
pipe = make_pipeline(vec, clf)
pipe.fit(twenty_train.data, twenty_train.target)
print_report(pipe)
In [18]:
eli5.show_prediction(clf, twenty_test.data[0], vec=vec,
                     target_names=twenty_test.target_names)
Out[18]:
It works, but quality is a bit worse. Also, it takes ages to train.
It looks like stop_words has no effect now - in fact, this is documented in the scikit-learn docs, so our stop_words='english' was useless. But at least it is now more obvious what the text looks like to a char ngram-based classifier. Grab a cup of tea and see how char_wb looks:
In [19]:
vec = TfidfVectorizer(analyzer='char_wb', ngram_range=(3,5))
clf = LogisticRegressionCV()
pipe = make_pipeline(vec, clf)
pipe.fit(twenty_train.data, twenty_train.target)
print_report(pipe)
In [20]:
eli5.show_prediction(clf, twenty_test.data[0], vec=vec,
                     target_names=twenty_test.target_names)
Out[20]:
The result is similar, with some minor changes. Quality is better for an unknown reason; maybe cross-word dependencies are not that important.
To check that, we can try fitting word n-grams instead of char n-grams. But let's deal with efficiency first. To handle large vocabularies we can use HashingVectorizer from scikit-learn; to make training faster we can employ SGDClassifier:
In [21]:
from sklearn.feature_extraction.text import HashingVectorizer
from sklearn.linear_model import SGDClassifier
vec = HashingVectorizer(stop_words='english', ngram_range=(1,2))
clf = SGDClassifier(n_iter=10, random_state=42)
pipe = make_pipeline(vec, clf)
pipe.fit(twenty_train.data, twenty_train.target)
print_report(pipe)
It was super fast! We're not choosing the regularization parameter using cross-validation though. Let's check what the model learned:
In [22]:
eli5.show_prediction(clf, twenty_test.data[0], vec=vec,
                     target_names=twenty_test.target_names,
                     targets=['sci.med'])
Out[22]:
The result looks similar to what we got with CountVectorizer. But with HashingVectorizer we don't even have a vocabulary! Why does it work?
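Because of the hashing trick: each token is mapped to a column index by a hash function, so no vocabulary needs to be stored. A minimal sketch (n_features below is an arbitrary small value for illustration):

demo_vec = HashingVectorizer(n_features=2 ** 10)
X_demo = demo_vec.transform(["medicine doctor", "doctor"])
# the same token always hashes to the same column, so both documents
# get a non-zero value in the "doctor" column
print(X_demo[0].nonzero()[1], X_demo[1].nonzero()[1])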
In [23]:
eli5.show_weights(clf, vec=vec, top=10,
                  target_names=twenty_test.target_names)
Out[23]:
Ok, we don't have a vocabulary, so we don't have feature names. Are we out of luck? Nope, eli5 has an answer for that: InvertableHashingVectorizer. It can be used to get feature names for HashingVectorizer without fitting a huge vocabulary. It still needs some data to learn the words -> hashes mapping though; we can use a random subset of the data to fit it.
In [24]:
from eli5.sklearn import InvertableHashingVectorizer
import numpy as np
In [25]:
ivec = InvertableHashingVectorizer(vec)
sample_size = len(twenty_train.data) // 10
X_sample = np.random.choice(twenty_train.data, size=sample_size)
ivec.fit(X_sample);
In [26]:
eli5.show_weights(clf, vec=ivec, top=20,
                  target_names=twenty_test.target_names)
Out[26]:
There are collisions (hover the mouse over features with "..."), and there are important features which were not seen in the random sample (FEATURE[...]), but overall it looks fine.
The "rutgers edu" bigram feature is suspicious though; it looks like part of a URL.
In [27]:
rutgers_example = [x for x in twenty_train.data if 'rutgers' in x.lower()][0]
print(rutgers_example)
Yep, it looks like the model learned this address instead of learning something useful.
In [28]:
eli5.show_prediction(clf, rutgers_example, vec=vec,
                     target_names=twenty_test.target_names,
                     targets=['soc.religion.christian'])
Out[28]:
Quoted text makes it too easy for the model to classify some of the messages; that won't generalize to new messages. So to improve the model, the next step could be to process the data further, e.g. remove quoted text or replace email addresses with a special token.
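For example, a cleaning step could look roughly like the sketch below (clean_message and EMAIL_RE are hypothetical helpers, not part of this tutorial):

import re

EMAIL_RE = re.compile(r'\S+@\S+')

def clean_message(text):
    # drop lines that look like quoted text (">" prefix is a common convention)
    lines = [line for line in text.splitlines()
             if not line.lstrip().startswith('>')]
    # replace anything that looks like an email address with a single token
    return EMAIL_RE.sub('EMAIL', '\n'.join(lines))

cleaned_train = [clean_message(text) for text in twenty_train.data]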
You get the idea: looking at features helps to understand how the classifier works. Maybe even more importantly, it helps to notice preprocessing bugs, data leaks, issues with the task specification - all the nasty problems you get in the real world.